import argparse
import torch
import torch.nn as nn
import utils.datasets as dl
from utils.load_trained_model import load_model
import pathlib
import matplotlib as mpl
import pickle
from distutils.util import strtobool
import os

mpl.use('Agg')
import matplotlib.pyplot as plt
from utils.compute_auc import compute_auc
from utils.temperature_wrapper import TemperatureWrapper
import ssl_utils as ssl

# Models evaluated when no --model_name is given on the command line.
# Each entry is (architecture, run folder, checkpoint, temperature, load_temp), i.e. the
# positional/keyword arguments forwarded to ``load_model``. Commented-out entries are earlier
# runs kept for reference.
_DEFAULT_MODEL_DESCRIPTIONS = [
    # ('WideResNet28x2', 'CEDA_27-02-2021_20:41:39', 'best', None, False),
    # ('WideResNet28x2', 'CEDA_28-02-2021_10:35:25', 'best', None, False),
    # ('WideResNet28x2', 'CEDA_01-03-2021_10:05:01', 'best', None, False),
    # ('WideResNet28x2', 'CEDA_20-09-2021_19:19:17', 'best', None, False),
    # ('WideResNet28x2', 'plain_12-06-2021_06:15:58', 'best', None, False),
    # ('WideResNet28x2', 'plain_14-06-2021_12:31:50', 'best', None, False),

    # ('WideResNet34x20', 'CEDA_27-03-2021_19:07:47', '250', None, False),

    ('WideResNet28x2', 'plain_30-09-2021_21:33:53', 'best', None, False),

    # ('WideResNet28x2', 'CEDA_14-09-2021_14:50:06', 'best', None, False),
    # ('WideResNet28x2', 'CEDA_16-09-2021_08:24:25', 'best', None, False),
    # ('WideResNet28x2', 'CEDA_22-09-2021_19:12:31', 'best', None, False),
    # ('WideResNet28x2', 'CEDA_24-09-2021_20:46:31', 'best', None, False),
    # ('WideResNet28x2', 'CEDA_27-09-2021_11:30:38', 'best', None, False),
    #
    # ('WideResNet28x2', 'CEDA_14-09-2021_15:02:41', 'best', None, False),
    # ('WideResNet28x2', 'CEDA_16-09-2021_08:23:20', 'best', None, False),
    # ('WideResNet28x2', 'CEDA_22-09-2021_19:12:30', 'best', None, False),
    # ('WideResNet28x2', 'CEDA_24-09-2021_21:08:01', 'best', None, False),
    # ('WideResNet28x2', 'CEDA_27-09-2021_11:00:22', 'best', None, False),

    # ('WideResNet28x2', 'plain_09-03-2021_22:30:34', 'best', None, False),
    # ('WideResNet28x2', 'plain_11-03-2021_08:45:17', 'best', None, False),
    # ('WideResNet28x2', 'plain_13-03-2021_19:54:51', 'best', None, False),

    # ('WideResNet28x2', 'plain_08-03-2021_09:51:19', 'best', None, False),
    # ('WideResNet28x2', 'plain_09-03-2021_22:33:13', 'best', None, False),
    # ('WideResNet28x2', 'plain_11-03-2021_08:45:16', 'best', None, False),
    # ('WideResNet28x2', 'plain_13-03-2021_19:54:52', 'best', None, False),

    # ('WideResNet28x2', 'CEDA_02-03-2021_17:33:05', 'best', None, False),
    # ('WideResNet40x10', 'CEDA_26-11-2020_18:10:51', 'best', None, False),
    # ('WideResNet70x16', 'CEDA_09-12-2020_18:47:56', 'final_swa', None, False),

    # ('ResNet50', 'CEDA_17-09-2020_20:22:52', 'best', None, False),
    # ('ResNet50', 'CEDA_18-09-2020_18:25:52', 'best', None, False),
    # ('ResNet50', 'CEDA_20-09-2020_14:29:00', 'best', None, False),
    # ('ResNet50', 'CEDA_24-09-2020_09:32:14', 'best', None, False),
    # ('ResNet50', 'CEDA_29-09-2020_10:27:18', 'best', None, False),

    # ('ResNet50', 'plain_08-09-2020_13:22:12', 'best', None, False),
    # ('ResNet50', 'plain_18-09-2020_11:54:27', 'best', None, False),
    # ('ResNet50', 'plain_19-09-2020_12:30:38', 'best', None, False),
    # ('ResNet50', 'plain_21-09-2020_12:28:32', 'best', None, False),
    # ('ResNet50', 'plain_24-09-2020_09:50:53', 'best', None, False),
    # ('ResNet50', 'plain_18-09-2020_11:54:28', 'best', None, False),
    # ('ResNet50', 'plain_19-09-2020_11:30:21', 'best', None, False),
    # ('ResNet50', 'plain_21-09-2020_12:26:02', 'best', None, False),
    # ('ResNet50', 'plain_24-09-2020_09:49:33', 'best', None, False),

    # ('WideResNet40x10', 'CEDA_15-11-2020_22:18:24', 'best_swa', None, False),
    # ('tresnetm', 'CEDA_25-08-2020_10:22:12', 'best', None, False),
    # ('shakedrop_pyramid272', 'CEDA_12-09-2020_07:40:02', 'best', None, False),
    # ('shakedrop_pyramid272', 'CEDA_24-09-2020_22:13:03', 'best', None, False),
    # ('shakedrop_pyramid272', 'CEDA_13-10-2020_11:04:37', 'best', None, False),
    # ('shakedrop_pyramid272', 'CEDA_14-10-2020_06:42:12', 'best', None, False),
    # ('shakedrop_pyramid272', 'CEDA_18-10-2020_14:18:59', 'best', None, False),
    # ('shakedrop_pyramid272', 'CEDA_21-10-2020_03:02:33', 'best', None, False),
    # ('shakedrop_pyramid272', 'CEDA_24-10-2020_15:05:05', 'best', None, False),

    # ('shakedrop_pyramid272', 'plain_12-09-2020_15:43:25', 'best', None, False),
    # ('shakedrop_pyramid272', 'plain_26-10-2020_09:43:31', 'best', None, False),
    # ('shakedrop_pyramid272', 'plain_28-10-2020_09:26:30', 'best', None, False),
    # ('shakedrop_pyramid272', 'plain_28-10-2020_23:29:06', 'best', None, False),
    # ('shakedrop_pyramid272', 'plain_31-10-2020_08:12:59', 'best', None, False),
    # ('shakedrop_pyramid272', 'plain_03-11-2020_10:27:50', 'best', None, False),
    # ('shakedrop_pyramid272', 'plain_05-11-2020_15:38:21', 'best', None, False),
    #
    # ('shakedrop_pyramid272', 'plain_26-10-2020_09:43:32', 'best', None, False),
    # ('shakedrop_pyramid272', 'plain_28-10-2020_09:25:21', 'best', None, False),
    # ('shakedrop_pyramid272', 'plain_28-10-2020_23:28:38', 'best', None, False),
    # ('shakedrop_pyramid272', 'plain_31-10-2020_08:13:12', 'best', None, False),
    # ('shakedrop_pyramid272', 'plain_02-11-2020_08:44:50', 'best', None, False),
    # ('shakedrop_pyramid272', 'plain_04-11-2020_16:44:28', 'best', None, False),

    # CIFAR100
    # ('ResNet50', 'CEDA_17-09-2020_20:23:07', 'best', None, False),
    # ('ResNet50', 'CEDA_20-09-2020_08:09:19', 'best', None, False),
    # ('ResNet50', 'CEDA_20-09-2020_16:46:47', 'best', None, False),
    # ('ResNet50', 'CEDA_24-09-2020_10:11:52', 'best', None, False),

    # ('ResNet50', 'plain_21-08-2020_08:23:42', 'best', None, False),
    # ('ResNet50', 'plain_18-09-2020_12:02:09', 'best', None, False),
    # ('ResNet50', 'plain_19-09-2020_22:03:46', 'best', None, False),
    # ('ResNet50', 'plain_21-09-2020_16:49:49', 'best', None, False),
    # ('ResNet50', 'plain_29-09-2020_07:58:53', 'best', None, False),
    #
    # ('ResNet50', 'plain_18-09-2020_12:02:03', 'best', None, False),
    # ('ResNet50', 'plain_19-09-2020_22:03:47', 'best', None, False),
    # ('ResNet50', 'plain_21-09-2020_16:49:28', 'best', None, False),
    # ('ResNet50', 'plain_30-09-2020_13:01:54', 'best', None, False),

    # ('shakedrop_pyramid272', 'CEDA_12-09-2020_11:08:13', 'best', None, False),
    # ('shakedrop_pyramid272', 'CEDA_25-09-2020_08:25:36', 'best', None, False),
    # ('shakedrop_pyramid272', 'CEDA_14-10-2020_07:12:17', 'best', None, False),
    # ('shakedrop_pyramid272', 'CEDA_15-10-2020_07:22:28', 'best', None, False),
    # ('shakedrop_pyramid272', 'CEDA_20-10-2020_12:44:38', 'best', None, False),
    # ('shakedrop_pyramid272', 'CEDA_23-10-2020_09:48:00', 'best', None, False),
    # ('shakedrop_pyramid272', 'CEDA_28-10-2020_00:04:21', 'best', None, False),
    #
    # ('shakedrop_pyramid272', 'plain_12-09-2020_15:42:09', 'best', None, False),
    # ('shakedrop_pyramid272', 'plain_26-10-2020_18:18:27', 'best', None, False),
    # ('shakedrop_pyramid272', 'plain_28-10-2020_10:04:53', 'best', None, False),
    # ('shakedrop_pyramid272', 'plain_30-10-2020_12:43:06', 'best', None, False),
    # ('shakedrop_pyramid272', 'plain_01-11-2020_12:21:34', 'best', None, False),
    # ('shakedrop_pyramid272', 'plain_03-11-2020_08:53:59', 'best', None, False),
    # ('shakedrop_pyramid272', 'plain_06-11-2020_14:17:48', 'best', None, False),
    #
    # ('shakedrop_pyramid272', 'plain_26-10-2020_18:18:26', 'best', None, False),
    # ('shakedrop_pyramid272', 'plain_28-10-2020_10:03:13', 'best', None, False),
    # ('shakedrop_pyramid272', 'plain_30-10-2020_10:02:44', 'best', None, False),
    # ('shakedrop_pyramid272', 'plain_01-11-2020_10:27:53', 'best', None, False),
    # ('shakedrop_pyramid272', 'plain_03-11-2020_11:25:19', 'best', None, False),
    # ('shakedrop_pyramid272', 'plain_07-11-2020_10:22:39', 'best', None, False),
]


def _select_device(gpu, bs):
    """Return ``(device_ids, device, batch_size)`` for the requested GPU indices.

    An empty ``gpu`` selects the CPU. A single index selects that CUDA device. Several
    indices return them as ``device_ids`` (for ``nn.DataParallel``) and multiply the batch
    size by their count.
    """
    if len(gpu) == 0:
        print('Warning! Computing on CPU')
        return None, torch.device('cpu'), bs
    if len(gpu) == 1:
        return None, torch.device('cuda:' + str(gpu[0])), bs
    device_ids = [int(i) for i in gpu]
    return device_ids, torch.device('cuda:' + str(device_ids[0])), bs * len(device_ids)


def _image_size(dataset, architecture):
    """Return the input resolution used for ``dataset`` / ``architecture``."""
    if dataset in ('cifar10', 'cifar100') and 'BiT' in architecture:
        return 128
    if dataset in ['pets', 'cars', 'food-101', 'flowers']:
        return 224
    return 32


def _load_evaluation_sets(dataset, cifar_subset, od_validation_set, od_validation_samples, bs, img_size):
    """Build the loaders used for temperature fitting and threshold computation.

    Returns ``(val_set, test_set, od_set, class_labels, num_classes)``.

    Raises:
        NotImplementedError: for unsupported datasets or CIFAR-100 configurations.
    """
    if dataset == 'cifar10':
        num_classes = 10
        samples_per_class = cifar_subset // num_classes
        if samples_per_class <= 0:
            val_set = dl.get_CIFAR10_1(batch_size=bs, shuffle=False, augm_type='none', size=img_size)
        else:
            val_set = ssl.get_CIFAR10_subset('val', samples_per_class,
                                             batch_size=bs, shuffle=False, augm_type='none', size=img_size)

        test_set = dl.get_CIFAR10(train=False, batch_size=bs, shuffle=False, augm_type='none', size=img_size)

        if od_validation_set == 'lsun_subset':
            print('Using LSUN OD')
            od_set = ssl.get_LSUN_scenes_subset('val', samples_per_class=od_validation_samples // 10, batch_size=bs,
                                                shuffle=False, size=img_size, augm_type='none')
        elif od_validation_set == 'uniformNoise':
            print('Using Uniform Noise OD')
            od_set = dl.get_noise_dataset(od_validation_samples, type='uniform', batch_size=bs, size=img_size)
        elif od_validation_set == 'tinyImages_subset':
            print(f'Using Tiny Images subset {od_validation_samples}')
            od_set = ssl.get_80MTinyImages_subset(od_validation_samples, exclude_cifar10_1=True, exclude_cifar=True,
                                                  batch_size=bs, size=img_size, augm_type='none')
        else:
            print('Using CIFAR100 Minus CIFAR10 OD')
            od_set = dl.cifar.get_CIFAR100MinusCIFAR10(train=True, batch_size=bs,
                                                       samples_per_class=od_validation_samples // 100,
                                                       size=img_size)

        return val_set, test_set, od_set, dl.cifar.get_CIFAR10_labels(), num_classes

    if dataset == 'cifar100':
        num_classes = 100
        if cifar_subset // num_classes <= 0:
            val_set = ssl.get_CIFAR100TrainValidation(train=False, batch_size=bs, shuffle=False, augm_type='none',
                                                      size=img_size)
        else:
            raise NotImplementedError()

        test_set = dl.get_CIFAR100(train=False, batch_size=bs, shuffle=False, augm_type='none', size=img_size)
        if od_validation_samples == 2_000:
            od_set = ssl.get_CIFAR10_subset('train', od_validation_samples // 10, batch_size=bs, augm_type='none',
                                            size=img_size)
        elif od_validation_samples == 50_000:
            od_set = dl.cifar.get_CIFAR10MinusCIFAR100(train=True, batch_size=bs, augm_type='none', size=img_size)
        else:
            raise NotImplementedError()

        return val_set, test_set, od_set, dl.cifar.get_CIFAR100_labels(), num_classes

    # Exact match: the previous ``dataset in 'svhn'`` accepted any substring such as 'sv' or ''.
    if dataset == 'svhn':
        val_set = ssl.get_SVHNValidationExtraSplit('validation-split', shuffle=False, batch_size=bs, augm_type='none')
        test_set = dl.get_SVHN('test', shuffle=False, batch_size=bs, augm_type='none')
        od_set = dl.get_CIFAR100(train=True, batch_size=bs, shuffle=False, augm_type='none')
        class_labels = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
        return val_set, test_set, od_set, class_labels, 10

    if dataset == 'imagenet100':
        val_set = ssl.get_ImageNet100_trainVal(train=False, batch_size=bs, shuffle=False, augm_type='none')
        test_set = dl.get_ImageNet100(train=False, batch_size=bs, shuffle=False, augm_type='none')
        od_set = dl.imagenet_subsets.get_ImageNet100OD(train=False, batch_size=bs, shuffle=False, augm_type='none')
        class_labels = dl.imagenet_subsets.get_ImageNet100_labels()
        return val_set, test_set, od_set, class_labels, len(class_labels)

    # Fine-grained datasets: the held-out split serves as both validation and test set.
    if dataset == 'pets':
        val_set = dl.get_pets('test', batch_size=bs, shuffle=False, augm_type='none', size=img_size)
        od_set = dl.imagenet_subsets.get_ImageNetWithout('pets', train=False, batch_size=bs, shuffle=False,
                                                         augm_type='none', size=img_size)
        class_labels = dl.pets.get_pets_labels()
    elif dataset == 'food-101':
        val_set = dl.get_food_101('val', batch_size=bs, shuffle=False, augm_type='none', size=img_size)
        od_set = dl.imagenet_subsets.get_ImageNetWithout('food-101', train=False, batch_size=bs, shuffle=False,
                                                         augm_type='none', size=img_size)
        class_labels = dl.food_101.get_food_101_labels()
    elif dataset == 'cars':
        val_set = dl.get_stanford_cars(False, batch_size=bs, shuffle=False, augm_type='none', size=img_size)
        od_set = dl.imagenet_subsets.get_ImageNetWithout('cars', train=False, batch_size=bs, shuffle=False,
                                                         augm_type='none', size=img_size)
        class_labels = dl.stanford_cars.get_stanford_cars_labels()
    elif dataset == 'flowers':
        val_set = dl.get_flowers('test', batch_size=bs, shuffle=False, augm_type='none', size=img_size)
        od_set = dl.imagenet_subsets.get_ImageNetWithout('flowers', train=False, batch_size=bs, shuffle=False,
                                                         augm_type='none', size=img_size)
        class_labels = dl.flowers.get_flowers_labels()
    else:
        raise NotImplementedError()
    return val_set, val_set, od_set, class_labels, len(class_labels)


def _load_auc_loaders(dataset, test_set, bs, img_size):
    """Return ``(in_loader, out_loaders)`` used by ``compute_auc`` for ``dataset``.

    ``out_loaders`` is a list of ``(name, loader)`` pairs.

    Raises:
        ValueError: if ``dataset`` has no AUC configuration.
    """
    if dataset == 'cifar10':
        in_loader = dl.get_CIFAR10(train=False, batch_size=bs, augm_type='none', size=img_size)
        out_loaders = [
            ('SVHN', dl.get_SVHN(split='train', batch_size=bs, augm_type='none', size=img_size)),
            ('CIFAR100', dl.get_CIFAR100(train=False, batch_size=bs, augm_type='none', size=img_size)),
            ('LSUN_CR', dl.get_LSUN_CR(train=False, batch_size=bs, size=img_size)),
            ('Flowers', dl.get_flowers(split='test', batch_size=bs, shuffle=False, augm_type='none', size=img_size)),
            ('Food-101N', dl.get_food_101N(split='val', batch_size=bs, shuffle=False, augm_type='none', size=img_size)),
            # ('OpenImages', dl.get_openImages(split='train', batch_size=bs, shuffle=False, augm_type='none', size=img_size)),
        ]
    elif dataset == 'cifar100':
        in_loader = dl.get_CIFAR100(train=False, batch_size=bs, augm_type='none', size=img_size)
        out_loaders = [
            ('SVHN', dl.get_SVHN(split='train', batch_size=bs, augm_type='none', size=img_size)),
            ('CIFAR10', dl.get_CIFAR10(train=False, batch_size=bs, augm_type='none', size=img_size)),
            ('LSUN_CR', dl.get_LSUN_CR(train=False, batch_size=bs, size=img_size)),
            ('FGVC-Aircraft', dl.get_fgvc_aircraft(split='test', batch_size=bs, shuffle=False, augm_type='none',
                                                   size=img_size)),
        ]
    elif dataset == 'svhn':
        in_loader = dl.get_SVHN('test', shuffle=False, batch_size=bs, augm_type='none')
        out_loaders = [
            ('CIFAR10', dl.get_CIFAR10(train=False, batch_size=bs, augm_type='none')),
            # Previously loaded CIFAR-10 under the 'CIFAR100' label.
            ('CIFAR100', dl.get_CIFAR100(train=False, batch_size=bs, augm_type='none')),
            ('LSUN_CR', dl.get_LSUN_CR(train=False, batch_size=bs)),
            ('FGVC-Aircraft', dl.get_fgvc_aircraft(split='test', batch_size=bs, shuffle=False, augm_type='none', size=32)),
            ('Flowers', dl.get_flowers(split='test', batch_size=bs, shuffle=False, augm_type='none', size=32)),
            ('Food-101N', dl.get_food_101N(split='val', batch_size=bs, shuffle=False, augm_type='none', size=32)),
        ]
    elif dataset == 'imagenet100':
        in_loader = test_set
        out_loaders = [
            ('Flowers', dl.get_flowers(split='test', batch_size=bs, shuffle=False, augm_type='none', size=img_size)),
            ('ImageNet100OD', dl.imagenet_subsets.get_ImageNet100OD(train=True, batch_size=bs, shuffle=False,
                                                                    augm_type='none', size=img_size)),
            ('Food-101N', dl.get_food_101N(split='val', batch_size=bs, shuffle=False, augm_type='none', size=img_size)),
            ('FGVC-Aircraft', dl.get_fgvc_aircraft(split='test', batch_size=bs, shuffle=False, augm_type='none',
                                                   size=img_size)),
        ]
    elif dataset == 'pets':
        in_loader = test_set
        out_loaders = [
            ('Flowers', dl.get_flowers(split='test', batch_size=bs, shuffle=False, augm_type='none', size=img_size)),
            ('Cars', dl.get_stanford_cars(train=False, batch_size=bs, shuffle=False, augm_type='none', size=img_size)),
            ('Food-101N', dl.get_food_101N(split='val', batch_size=bs, shuffle=False, augm_type='none', size=img_size)),
            ('FGVC-Aircraft', dl.get_fgvc_aircraft(split='test', batch_size=bs, shuffle=False, augm_type='none',
                                                   size=img_size)),
        ]
    elif dataset == 'food-101':
        in_loader = test_set
        out_loaders = [
            ('Flowers', dl.get_flowers(split='test', batch_size=bs, shuffle=False, augm_type='none', size=img_size)),
            ('Cars', dl.get_stanford_cars(train=False, batch_size=bs, shuffle=False, augm_type='none', size=img_size)),
            ('Pets', dl.get_pets(split='test', batch_size=bs, shuffle=False, augm_type='none', size=img_size)),
            ('FGVC-Aircraft', dl.get_fgvc_aircraft(split='test', batch_size=bs, shuffle=False, augm_type='none',
                                                   size=img_size)),
        ]
    elif dataset == 'cars':
        in_loader = test_set
        out_loaders = [
            ('Flowers', dl.get_flowers(split='test', batch_size=bs, shuffle=False, augm_type='none', size=img_size)),
            ('Food-101N', dl.get_food_101N(split='val', batch_size=bs, shuffle=False, augm_type='none', size=img_size)),
            ('Pets', dl.get_pets(split='test', batch_size=bs, shuffle=False, augm_type='none', size=img_size)),
            ('FGVC-Aircraft', dl.get_fgvc_aircraft(split='test', batch_size=bs, shuffle=False, augm_type='none',
                                                   size=img_size)),
        ]
    elif dataset == 'flowers':
        in_loader = test_set
        out_loaders = [
            ('Cars', dl.get_stanford_cars(train=False, batch_size=bs, shuffle=False, augm_type='none', size=img_size)),
            ('Food-101N', dl.get_food_101N(split='val', batch_size=bs, shuffle=False, augm_type='none', size=img_size)),
            ('Pets', dl.get_pets(split='test', batch_size=bs, shuffle=False, augm_type='none', size=img_size)),
            ('FGVC-Aircraft', dl.get_fgvc_aircraft(split='test', batch_size=bs, shuffle=False, augm_type='none',
                                                   size=img_size)),
        ]
    else:
        raise ValueError(f'Dataset {dataset} not supported')
    return in_loader, out_loaders


def _dump_pickle(obj, file_path):
    """Pickle ``obj`` to ``file_path``, closing the file handle deterministically."""
    with open(file_path, 'wb') as f:
        pickle.dump(obj, f)


def _accuracy(model, loader, device):
    """Return the top-1 accuracy of ``model`` over ``loader``."""
    correct = 0
    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            out = model(data)
            correct += torch.sum(torch.max(out, dim=1)[1] == target).item()
    return float(correct) / len(loader.dataset)


def _validation_predictions(model, loader, device, num_classes):
    """Run ``model`` over ``loader`` and collect CPU predictions.

    Returns ``(model_probabilities, predicted_classes, gt_classes, accuracy)``.
    ``model_probabilities`` has shape ``(len(dataset), num_classes)`` and holds float64
    softmax outputs. The two class tensors are int64.
    """
    n = len(loader.dataset)
    predicted_classes = torch.zeros(n, dtype=torch.long)
    gt_classes = torch.zeros(n, dtype=torch.long)
    model_probabilities = torch.zeros((n, num_classes), dtype=torch.double)
    correct = 0
    idx = 0
    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            out = model(data)
            predictions = torch.max(out, dim=1)[1]
            correct += torch.sum(predictions == target).item()
            end = idx + data.shape[0]
            model_probabilities[idx:end, :] = torch.softmax(out.cpu().double(), dim=1)
            gt_classes[idx:end] = target.detach().cpu()
            predicted_classes[idx:end] = predictions.detach().cpu()
            idx = end
    return model_probabilities, predicted_classes, gt_classes, float(correct) / n


def _id_thresholds(model_probabilities, predicted_classes, gt_classes, num_classes, target_accuracies):
    """Per-class confidence thresholds that reach each target accuracy on the validation set.

    For each class, the samples predicted as that class are sorted by max-softmax confidence.
    For every candidate confidence, the accuracy of all predictions at or above it is computed.
    The smallest confidence whose accuracy is ``>= target`` is chosen. If none reaches the
    target, the confidence with the highest accuracy is used instead.

    Returns a float32 tensor of shape ``(len(target_accuracies), num_classes)``.
    """
    max_confs, _ = torch.max(model_probabilities, dim=1)
    is_correct = predicted_classes == gt_classes

    possible_thresholds = []
    possible_accuracies = []
    for class_idx in range(num_classes):
        # Boolean masking always yields a 1-D tensor. nonzero().squeeze() collapsed to 0-d
        # when a class was predicted exactly once and then crashed on .shape[0].
        class_mask = predicted_classes == class_idx
        confs_sorted, sort_idcs = torch.sort(max_confs[class_mask], descending=False)
        correct_sorted = is_correct[class_mask][sort_idcs]
        # Accuracy over positions i..end, via a reversed cumulative sum. The counts are exact
        # integers in float64, so this matches the per-suffix Python division it replaces.
        correct_from_i = torch.flip(torch.cumsum(torch.flip(correct_sorted.double(), dims=[0]), dim=0), dims=[0])
        remaining = torch.arange(confs_sorted.shape[0], 0, -1, dtype=torch.double)
        possible_thresholds.append(confs_sorted)
        possible_accuracies.append(correct_from_i / remaining)

    class_thresholds = torch.zeros((len(target_accuracies), num_classes))
    for threshold_idx, threshold in enumerate(target_accuracies):
        for class_idx in range(num_classes):
            thresholds = possible_thresholds[class_idx]
            accuracies = possible_accuracies[class_idx]
            reaches_target = accuracies >= threshold
            if reaches_target.any():
                class_thresholds[threshold_idx, class_idx] = torch.min(thresholds[reaches_target])
            else:
                # NOTE(review): argmax raises if the class is never predicted on the validation set.
                class_thresholds[threshold_idx, class_idx] = thresholds[torch.argmax(accuracies)]
    return class_thresholds


def _lowest_validation_thresholds(model_probabilities, predicted_classes, gt_classes, num_classes):
    """Per-class lowest probability among correctly classified validation samples."""
    thresholds = torch.zeros(num_classes)
    for class_idx in range(num_classes):
        correct_mask = (predicted_classes == class_idx) & (gt_classes == class_idx)
        # NOTE(review): torch.min raises if no validation sample of this class is classified correctly.
        thresholds[class_idx] = torch.min(model_probabilities[correct_mask, class_idx])
    return thresholds


def _max_softmax_outputs(model, loader, device, conf_dtype=torch.double, with_targets=False):
    """Run ``model`` over ``loader`` and return CPU ``(confidences, predictions, targets)``.

    ``confidences`` are the max softmax probabilities, stored as ``conf_dtype``.
    ``targets`` is ``None`` unless ``with_targets`` is set.
    """
    n = len(loader.dataset)
    confs_all = torch.zeros(n, dtype=conf_dtype)
    preds_all = torch.zeros(n, dtype=torch.long)
    targets_all = torch.zeros(n, dtype=torch.long) if with_targets else None
    idx = 0
    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            confs, predictions = torch.max(torch.softmax(model(data), dim=1), dim=1)
            next_idx = min(n, idx + len(target))
            confs_all[idx:next_idx] = confs
            preds_all[idx:next_idx] = predictions
            if with_targets:
                targets_all[idx:next_idx] = target
            idx = next_idx
    return confs_all, preds_all, targets_all


def _od_exclusion_thresholds(od_confs, od_preds, num_classes, exclusion_thresholds):
    """Per-class thresholds taken from the sorted OD confidences predicted as each class.

    Entry ``[t, c]`` is the ascending-sorted OD confidence at position
    ``int(exclusion_thresholds[t] * n_c)``, where ``n_c`` is the number of OD samples
    predicted as class ``c``. It is 0.0 when ``n_c == 0``.
    """
    thresholds = torch.zeros((len(exclusion_thresholds), num_classes))
    for class_idx in range(num_classes):
        class_confs = torch.sort(od_confs[od_preds == class_idx], descending=False)[0]
        for threshold_idx, threshold in enumerate(exclusion_thresholds):
            if torch.numel(class_confs) == 0:
                value = 0.0
            else:
                value = class_confs[int(threshold * len(class_confs))]
            thresholds[threshold_idx, class_idx] = value
    return thresholds


def _combo_thresholds(od_confs, od_preds, val_confs, val_targets, num_classes, exclusion_thresholds):
    """Per-class thresholds from merged OD and validation confidences.

    For class ``c``, the OD samples predicted as ``c`` and the validation samples labelled
    ``c`` are merged and sorted by descending confidence. For each ratio, the threshold is
    the confidence at the last position where the running fraction of validation samples is
    ``>= ratio``, or position 0 if there is none.
    """
    thresholds = torch.zeros((len(exclusion_thresholds), num_classes))
    for class_idx in range(num_classes):
        od_matching_confs = od_confs[od_preds == class_idx]
        val_matching_confs = val_confs[val_targets == class_idx]

        confs_merged = torch.cat([od_matching_confs, val_matching_confs])
        sort_idcs = torch.argsort(confs_merged, descending=True)
        is_val = torch.cat([torch.zeros(len(od_matching_confs)), torch.ones(len(val_matching_confs))])
        is_val_sorted = is_val[sort_idcs]
        confs_sorted = confs_merged[sort_idcs]

        val_ratio = torch.cumsum(is_val_sorted, dim=0) / torch.arange(1, len(is_val) + 1)

        for j, threshold in enumerate(exclusion_thresholds):
            hits = torch.nonzero(val_ratio >= threshold, as_tuple=False)
            last_idx = int(hits[-1]) if hits.numel() > 0 else 0
            thresholds[j, class_idx] = confs_sorted[last_idx]
    return thresholds


def main(gpu, model_name, model_architecture, model_checkpoint, dataset, cifar_subset, od_dataset, od_validation_set,
         unlabeled_samples, od_validation_samples, thresholds_only=False):
    """Compute and pickle per-class confidence thresholds, and optionally OOD AUCs, for trained models.

    For every model description this function:
      1. evaluates test accuracy;
      2. fits a temperature on the validation set and writes ``{folder}_temperature.pickle``;
      3. for the raw model and for the temperature-scaled model (file suffix ``_T``), writes
         ``{folder}_id_thresholds{suffix}.pickle``, with per-class thresholds reaching each
         target accuracy. When an OD set is available it also writes
         ``{folder}_id_thresholds_from_od{suffix}.pickle``, holding the lowest correct
         validation probability per class plus per-class OD confidence thresholds;
      4. runs ``compute_auc`` unless ``thresholds_only`` is set.

    Args:
        gpu: GPU indices. Empty selects the CPU; more than one enables ``nn.DataParallel``.
        model_name: run folder of a single model to evaluate. When ``None``,
            ``_DEFAULT_MODEL_DESCRIPTIONS`` is evaluated instead.
        model_architecture: architecture name passed to ``load_model``.
        model_checkpoint: checkpoint name passed to ``load_model``.
        dataset: in-distribution dataset name.
        cifar_subset: total CIFAR validation subset size, split evenly over classes.
        od_dataset: part of the output directory. ``'tinyImages'`` uses
            ``ssl.get_dataset_thresholds_dir``.
        od_validation_set: OD validation data choice (CIFAR-10 only).
        unlabeled_samples: only used in the output directory name.
        od_validation_samples: number of OD validation samples.
        thresholds_only: skip building the AUC loaders and running ``compute_auc``.

    Raises:
        NotImplementedError, ValueError: for unsupported datasets or configurations.
    """
    exclusion_thresholds = [0.998, 0.997, 0.996, 0.995, 0.994, 0.99, 0.98, 0.97, 0.96, 0.95, 0.9, 0.8]
    # The ID target accuracies reuse the same grid.
    id_target_accuracies = exclusion_thresholds

    device_ids, device, bs = _select_device(gpu, 1024)

    if model_name is not None:
        print('Using passed model name')
        model_descriptions = [(model_architecture, model_name, model_checkpoint, None, False)]
    else:
        model_descriptions = _DEFAULT_MODEL_DESCRIPTIONS

    for architecture, folder, checkpoint, init_temperature, load_temp in model_descriptions:
        img_size = _image_size(dataset, architecture)

        print('\n##########################\n')

        val_set, test_set, od_set, _, num_classes = _load_evaluation_sets(
            dataset, cifar_subset, od_validation_set, od_validation_samples, bs, img_size)

        if od_dataset == 'tinyImages':
            path = ssl.get_dataset_thresholds_dir(dataset)
        else:
            path = os.path.join('DatasetClassifications/',
                                f'{dataset}_{cifar_subset}_{od_dataset}_{unlabeled_samples}',
                                'Thresholds')
        pathlib.Path(path).mkdir(parents=True, exist_ok=True)

        if thresholds_only:
            in_loader, out_loaders = None, None
        else:
            in_loader, out_loaders = _load_auc_loaders(dataset, test_set, bs, img_size)

        model = load_model(architecture, folder, checkpoint, init_temperature, device, load_temp=load_temp,
                           dataset=dataset)
        if device_ids is not None and len(device_ids) > 1:
            model = nn.DataParallel(model, device_ids=device_ids)
        model.eval()
        print(f'{path} {folder}')

        test_acc = _accuracy(model, test_set, device)
        print(f'Test accuracy {test_acc}')

        fitted_temperature = TemperatureWrapper.compute_temperature(model, val_set, device)
        print(f'Temperature {fitted_temperature}')
        _dump_pickle(fitted_temperature, os.path.join(path, f'{folder}_temperature.pickle'))

        for temperature in [None, fitted_temperature]:
            if temperature is not None:
                calibrated_model = TemperatureWrapper(model, T=temperature)
                temperature_postfix = '_T'
                print(f'Using Temperature {temperature}')
            else:
                calibrated_model = model
                temperature_postfix = ''
                print('No Temperature scaling')

            model_probabilities, predicted_classes, gt_classes, val_acc = _validation_predictions(
                calibrated_model, val_set, device, num_classes)
            print(f'Validation accuracy {val_acc}')

            class_thresholds = _id_thresholds(model_probabilities, predicted_classes, gt_classes,
                                              num_classes, id_target_accuracies)
            id_dict = {f'{threshold:.3f}': class_thresholds[threshold_idx, :]
                       for threshold_idx, threshold in enumerate(id_target_accuracies)}
            _dump_pickle(id_dict, os.path.join(path, f'{folder}_id_thresholds{temperature_postfix}.pickle'))

            # Fallback thresholds, stored alongside the OD-derived ones.
            od_dict = {'lowest_validation': _lowest_validation_thresholds(
                model_probabilities, predicted_classes, gt_classes, num_classes)}

            if od_set is not None:
                od_confs, od_preds, _ = _max_softmax_outputs(calibrated_model, od_set, device)
                od_thresholds = _od_exclusion_thresholds(od_confs, od_preds, num_classes, exclusion_thresholds)
                for threshold_idx, threshold in enumerate(exclusion_thresholds):
                    od_dict[f'{threshold:.3f}'] = od_thresholds[threshold_idx, :]
                _dump_pickle(od_dict,
                             os.path.join(path, f'{folder}_id_thresholds_from_od{temperature_postfix}.pickle'))

                # Combo thresholds. The OD confidences from the pass above are reused (float32
                # softmax values round-trip exactly through float64), so only the validation set
                # needs an extra forward pass.
                val_confs, _, val_targets = _max_softmax_outputs(calibrated_model, val_set, device,
                                                                 conf_dtype=torch.float, with_targets=True)
                # NOTE(review): this result is neither saved nor returned, matching the previous code.
                # Confirm whether it should be pickled.
                _combo_thresholds(od_confs.float(), od_preds, val_confs, val_targets,
                                  num_classes, exclusion_thresholds)
            print('Thresholds calculated')

            if not thresholds_only:
                compute_auc(calibrated_model, in_loader, out_loaders, device)

if __name__ == '__main__':
    cli = argparse.ArgumentParser(description='Parse arguments.', prefix_chars='-')

    # Hardware selection.
    cli.add_argument('--gpu', '--list', nargs='+', default=[0],
                     help='GPU indices, if more than 1 parallel modules will be called')

    # Data / OD configuration.
    cli.add_argument('--dataset', type=str, default='cifar10')
    cli.add_argument('--cifar_subset', type=int, default=4000)
    cli.add_argument('--od_dataset', type=str, default='tinyImages_subset')
    cli.add_argument('--od_validation_set', type=str, default=None)
    cli.add_argument('--unlabeled_samples', type=int, default=1_000_000)
    cli.add_argument('--od_validation_samples', type=int, default=2_000)

    # Model selection.
    cli.add_argument('--model_architecture', type=str, default='WideResNet28x2')
    cli.add_argument('--model_name', type=str, default=None)
    cli.add_argument('--model_checkpoint', type=str, default='best')
    cli.add_argument('--thresholds_only', type=lambda value: bool(strtobool(value)), default=False)

    args = cli.parse_args()

    # When --od_validation_set is omitted, fall back to --od_dataset.
    resolved_od_validation_set = args.od_validation_set
    if resolved_od_validation_set is None:
        resolved_od_validation_set = args.od_dataset

    main(args.gpu, args.model_name, args.model_architecture, args.model_checkpoint, args.dataset,
         args.cifar_subset, args.od_dataset, resolved_od_validation_set, args.unlabeled_samples,
         args.od_validation_samples, thresholds_only=args.thresholds_only)
